"""
Investigate how a VAE chooses among multiple transformations by merging
several transformation trajectories into a single training dataset.

Composition of the merged dataset:
0: y
1: x
2: diagonal
"""
import argparse
from collections import defaultdict

from torch import optim
from torch.utils.data import DataLoader
import wandb
from tqdm import trange

from disvae.models.anneal import get_anneal
from disvae.models.losses import get_loss_s
from disvae.models.vae import VAE
from disvae.models.encoders import get_encoder
from disvae.models.decoders import get_decoder
from utils.visualize import *
from exps.patterns import data_generator
import numpy as np

from utils.helpers import set_seed

# Default hyperparameters; each entry becomes a CLI flag with the same
# name and type, so every run setting can be overridden from the shell.
hyperparameter_defaults = dict(
    batch_size=256,
    learning_rate=0.0005,
    epochs=2001,
    dimension=3,
    angle=30,
    loss='btcvae',
    beta=50.0,
    img_id=0,
    random_seed=224,
)

parser = argparse.ArgumentParser()
for name, default in hyperparameter_defaults.items():
    parser.add_argument(f'--{name}', default=default, type=type(default))
config = args = parser.parse_args()

set_seed(config.random_seed)
wandb.init(project="exps", config=config, group='transformation')

def train(dl, model, loss_f, epochs,log=True):
    """Train `model` on batches from `dl` for `epochs` epochs, logging to wandb.

    Args:
        dl: DataLoader yielding (img, label) pairs; images are moved to CUDA.
        model: VAE-style module returning (reconstruction, latent_dist,
            latent_sample) when called on an image batch.
        loss_f: loss callable taking (img, recon, latent_dist, is_training,
            storer, latent_sample=...). If it raises ValueError it is instead
            driven through its `call_optimize` hook (losses that manage their
            own optimizers, e.g. FactorVAE).
        epochs: number of full passes over `dl`.
        log: when True, push the per-epoch `storer` dict to wandb each epoch.

    Returns:
        The `storer` dict from the final epoch (scalar means plus wandb media).
    """
    opt = optim.AdamW(model.parameters(), config.learning_rate)
    for e in trange(epochs):
        # Per-epoch metric accumulator; loss_f is expected to append
        # per-batch scalars into its list values.
        storer = defaultdict(list)
        loss_set = []
        for img,label in dl:
            img = img.cuda()
            try:
                recon_batch, latent_dist, latent_sample = model(img)
                loss = loss_f(img, recon_batch, latent_dist, model.training,
                              storer, latent_sample=latent_sample)

                opt.zero_grad()
                loss.backward()
                opt.step()
            except ValueError:
                # for losses that use multiple optimizers (e.g. Factor)
                loss = loss_f.call_optimize(img, model, opt, storer)

            loss_set.append(loss.item())
            if e%100 ==0:
                # Every 100th epoch: log the latent samples as a point cloud
                # (latent coords concatenated with label columns; assumes a
                # 3-D latent so the result is xyzrgb — TODO confirm).
                # Re-assigned every batch, so only the epoch's last batch is kept.
                points=latent_sample.detach().cpu().numpy()
                colors=label.detach().cpu().numpy()
                storer['pc']=wandb.Object3D(np.concatenate([points,colors],1))

                # Scatter each of the 3 latent dims against label columns 1
                # and 2 to eyeball latent/factor correlation.
                fig, axes = plt.subplots(2, 3, figsize=(15, 10))
                for i in range(2):
                    for j in range(3):
                        axes[i, j].scatter(points[:, j], colors[:, 1 + i])
                storer['correlated']=wandb.Image(fig)
                plt.close()


        # Reduce list-valued metrics (per-batch scalars) to epoch means;
        # non-list entries (wandb media objects) pass through untouched.
        for k, v in storer.items():
            if isinstance(v, list):
                storer[k] = np.mean(v)
        storer['loss'] = np.mean(loss_set)
        if e % 100 == 0:
            # First sample of the last batch: reconstruction vs. input.
            storer['recon']= wandb.Image(recon_batch[0, 0])
            storer['img']= wandb.Image(img[0, 0])

            # Latent traversal is plotted on CPU in eval mode, then the model
            # is restored to GPU/train. NOTE(review): `dataset` and `dim` are
            # module-level globals defined later in this file.
            model.cpu()
            model.eval()
            fig = plt_sample_traversal(dataset[20:21], model, 7, range(dim), r=2)
            storer['traversal']=fig
            model.cuda()
            model.train()
            plt.close()
        storer['epoch'] = e
        if log:
            wandb.log(storer, sync=False)
    return storer


# Generate the composite dataset: three 1-D slices ("transformations")
# through the rotated-pattern grid, all taken at one fixed rotation angle.
img_id = config.img_id
epochs = config.epochs
dim = config.dimension
beta = config.beta

img = torch.load(f'patterns/{img_id}.pat')
imgs, labels = data_generator.gen_rotation(img, img.shape)
angle = config.angle

# Transformation 2 walks the grid diagonal; stack its 40 steps up front.
diag_imgs = torch.stack([imgs[angle, i, i] for i in range(40)])
diag_labels = torch.stack([labels[angle, i, i] for i in range(40)])

# 0: vary the first grid axis (y), 1: vary the second (x), 2: diagonal.
dataset = torch.cat([
    imgs[angle, :, 0].reshape(-1, 1, 64, 64),
    imgs[angle, 0, :].reshape(-1, 1, 64, 64),
    diag_imgs.reshape(-1, 1, 64, 64),
])
targets = torch.cat([
    labels[angle, :, 0].reshape(-1, 3),
    labels[angle, 0, :].reshape(-1, 3),
    diag_labels.reshape(-1, 3),
])

dl = DataLoader(list(zip(dataset, targets)),
                num_workers=0, batch_size=config.batch_size, shuffle=True)


# Build the Burgess-architecture VAE on GPU in train mode
# (`.cuda()` and `.train()` both return the module itself).
vae = VAE((1, 64, 64), get_encoder('Burgess'), get_decoder('Burgess'), dim).cuda().train()

iterations = len(dl) * epochs  # total optimizer steps, drives the anneal schedule
anneal = get_anneal('monotonic', iterations, beta, 40)
loss_f = get_loss_s(config.loss, anneal, 60, len(dl.dataset))
storer = train(dl, vae, loss_f, epochs)

# Leave the trained model on CPU in eval mode and flush pending figures.
vae.eval()
vae.cpu()
plt.show()
